package com.jl.thread;

import cn.hutool.core.map.MapUtil;
import com.jl.springbean.util.JLSpringBean;
import com.jl.thread.config.JLThreadConfig;
import lombok.AllArgsConstructor;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import javax.annotation.PostConstruct;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Spring Boot thread pools.
 *
 * <p>On startup, registers one {@link ThreadPoolTaskExecutor} bean per entry of
 * {@link JLThreadConfig#getParam()}, plus a built-in default pool named "jl-thread". The default is
 * skipped if the configuration already declares a pool with that name. The {@code exec} methods fan a
 * bean method out over one of those pools (one task per argument array), wait for every task, and
 * merge the results.
 */
@EnableAsync
@Configuration
@AllArgsConstructor
public class JLThread {

    /** Bean name (and thread-name prefix) of the built-in default pool. */
    private static final String DEFAULT_POOL_NAME = "jl-thread";

    private static final Logger LOG = Logger.getLogger(JLThread.class.getName());

    /** Primitive parameter type to wrapper type, used to match boxed arguments against primitive parameters. */
    private static final Map<Class<?>, Class<?>> PRIMITIVE_WRAPPERS = new HashMap<>(16);

    static {
        PRIMITIVE_WRAPPERS.put(boolean.class, Boolean.class);
        PRIMITIVE_WRAPPERS.put(byte.class, Byte.class);
        PRIMITIVE_WRAPPERS.put(char.class, Character.class);
        PRIMITIVE_WRAPPERS.put(short.class, Short.class);
        PRIMITIVE_WRAPPERS.put(int.class, Integer.class);
        PRIMITIVE_WRAPPERS.put(long.class, Long.class);
        PRIMITIVE_WRAPPERS.put(float.class, Float.class);
        PRIMITIVE_WRAPPERS.put(double.class, Double.class);
    }

    private final JLThreadConfig threadConfig;

    private final JLSpringBean jlSpringBean;

    private final ApplicationContext applicationContext;

    /**
     * Registers a {@link ThreadPoolTaskExecutor} bean for every configured pool plus the default pool.
     *
     * <p>Null sizing values fall back to core=5, max=20, queue=100 and keepAlive=60s. Each pool uses
     * {@link ThreadPoolExecutor.CallerRunsPolicy}, so a saturated pool runs the task on the submitting thread.
     */
    @PostConstruct
    private void register() {
        // Work on a private copy so the configuration bean's own list is never mutated.
        List<JLThreadConfig.Param> params = new ArrayList<>();
        List<JLThreadConfig.Param> configured = threadConfig.getParam();
        if (configured != null) {
            params.addAll(configured);
        }
        // Only add the built-in pool when the user has not defined one with the same bean name;
        // registering the same name twice can only clash or silently discard one definition.
        boolean defaultOverridden = params.stream().anyMatch(p -> p != null && DEFAULT_POOL_NAME.equals(p.getName()));
        if (!defaultOverridden) {
            params.add(new JLThreadConfig.Param().setCore(2).setMax(8).setQueue(40).setName(DEFAULT_POOL_NAME).setKeep(30));
        }
        params.forEach(param -> {
            Map<String, Object> map = MapUtil.builder(new HashMap<String, Object>(6))
                    .put("corePoolSize", param.getCore() == null ? 5 : param.getCore())
                    .put("maxPoolSize", param.getMax() == null ? 20 : param.getMax())
                    .put("queueCapacity", param.getQueue() == null ? 100 : param.getQueue())
                    .put("keepAliveSeconds", param.getKeep() == null ? 60 : param.getKeep())
                    .put("threadNamePrefix", param.getName())
                    .put("rejectedExecutionHandler", new ThreadPoolExecutor.CallerRunsPolicy())
                    .build();
            jlSpringBean.registerBean(param.getName(), ThreadPoolTaskExecutor.class, map);
        });
    }

    /**
     * Runs a method of a Spring bean (looked up by type) on a pool, once per argument array.
     *
     * @param threadName  bean name of the {@link ThreadPoolTaskExecutor} to run on
     * @param beanClass   type of the Spring bean whose method is invoked
     * @param beanMethod  name of the public method to invoke
     * @param resultClass element type of the returned list (used only for type inference)
     * @param params      method arguments; each array is one invocation on its own task
     * @param <T>         result element type
     * @return the results of all successful invocations, in completion order
     */
    public <T> List<T> exec(String threadName, Class<?> beanClass, String beanMethod, Class<T> resultClass, List<Object[]> params) {
        Object bean = applicationContext.getBean(beanClass);
        return exec(threadName, bean, beanMethod, resultClass, params);
    }

    /**
     * Runs a method of the given bean on a pool, once per argument array, and waits for every task.
     *
     * <p>A task whose invocation fails is logged and contributes no result. If the calling thread is
     * interrupted while waiting, the interrupt flag is restored and the results collected so far are
     * returned. Tasks still running may keep appending to that (synchronized) list.
     *
     * <p>NOTE(review): the pools use {@link ThreadPoolExecutor.CallerRunsPolicy}, which discards tasks
     * once the executor is shut down. Calling this on a shut-down pool would therefore wait forever.
     *
     * @param threadName  bean name of the {@link ThreadPoolTaskExecutor} to run on
     * @param bean        target object whose method is invoked
     * @param beanMethod  name of the public method to invoke; resolved against each argument array
     *                    (exact runtime types first, then assignable / boxed-primitive / null-tolerant match)
     * @param resultClass element type of the returned list (used only for type inference)
     * @param params      method arguments; each array is one invocation on its own task. A {@code null}
     *                    array means "no arguments". A {@code null} or empty list yields an empty result.
     * @param <T>         result element type
     * @return the results of all successful invocations, in completion order
     */
    public <T> List<T> exec(String threadName, Object bean, String beanMethod, Class<T> resultClass, List<Object[]> params) {
        List<T> list = Collections.synchronizedList(new ArrayList<>());
        // Resolve the pool first so an unknown pool name still fails fast even with no work.
        ThreadPoolTaskExecutor executor = applicationContext.getBean(threadName, ThreadPoolTaskExecutor.class);
        if (params == null || params.isEmpty()) {
            return list;
        }
        CountDownLatch latch = new CountDownLatch(params.size());
        for (Object[] param : params) {
            // Method.invoke treats a null argument array as "no arguments"; mirror that here.
            Object[] args = param == null ? new Object[0] : param;
            executor.execute(() -> {
                try {
                    Method method = findMethod(bean.getClass(), beanMethod, args);
                    // resultClass only drives generic inference; the returned value is not type-checked.
                    @SuppressWarnings("unchecked")
                    T result = (T) method.invoke(bean, args);
                    list.add(result);
                } catch (Exception e) {
                    LOG.log(Level.SEVERE, "JLThread task failed: " + beanMethod, e);
                } finally {
                    latch.countDown();
                }
            });
        }
        try {
            latch.await();
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers instead of swallowing it.
            Thread.currentThread().interrupt();
        }
        return list;
    }

    /**
     * Resolves a public method named {@code name} on {@code type} that can accept {@code args}.
     *
     * <p>First tries an exact lookup on the arguments' runtime classes. If that fails, it returns the
     * first public method whose parameters accept the arguments. That fallback covers interface or
     * supertype parameters, primitive parameters given a boxed value of the same primitive type, and
     * {@code null} arguments for reference parameters.
     */
    private static Method findMethod(Class<?> type, String name, Object[] args) throws NoSuchMethodException {
        Class<?>[] argTypes = new Class<?>[args.length];
        boolean hasNull = false;
        for (int i = 0; i < args.length; i++) {
            if (args[i] == null) {
                hasNull = true;
            } else {
                argTypes[i] = args[i].getClass();
            }
        }
        if (!hasNull) {
            try {
                return type.getMethod(name, argTypes);
            } catch (NoSuchMethodException ignored) {
                // No exact match; fall through to the assignability scan below.
            }
        }
        for (Method candidate : type.getMethods()) {
            if (candidate.getName().equals(name) && accepts(candidate.getParameterTypes(), args)) {
                return candidate;
            }
        }
        throw new NoSuchMethodException(type.getName() + "." + name + Arrays.toString(argTypes));
    }

    /** Returns whether a method with {@code paramTypes} can be invoked with {@code args}. */
    private static boolean accepts(Class<?>[] paramTypes, Object[] args) {
        if (paramTypes.length != args.length) {
            return false;
        }
        for (int i = 0; i < args.length; i++) {
            Class<?> paramType = paramTypes[i];
            if (args[i] == null) {
                // null cannot be unboxed into a primitive parameter.
                if (paramType.isPrimitive()) {
                    return false;
                }
            } else {
                Class<?> target = paramType.isPrimitive() ? PRIMITIVE_WRAPPERS.get(paramType) : paramType;
                if (target == null || !target.isInstance(args[i])) {
                    return false;
                }
            }
        }
        return true;
    }
}
